Inspect

(Notebook only) Interactive exploration of trained encoder and embeddings
# Safety guard: refuse to run when /app/data exists (the solveit VM),
# since this notebook's memory requirements would crash that environment.
import os
assert not os.path.isdir('/app/data'), "Do not try to run this on solveit. The memory requirements will crash the VM."
import torch
from torch.utils.data import DataLoader
from omegaconf import OmegaConf
from midi_rae.vit import ViTEncoder, ViTDecoder
from midi_rae.swin import SwinEncoder, SwinDecoder
from midi_rae.data import PRPairDataset
from midi_rae.viz import make_emb_viz, viz_mae_recon
from midi_rae.utils import load_checkpoint
import matplotlib.pyplot as plt

# Interactive visualization (without wandb logging)
import plotly.io as pio
pio.renderers.default = 'notebook'  # render plotly figures inline instead of opening a browser
from midi_rae.viz import umap_project, pca_project, plot_embeddings_3d, make_emb_viz, viz_mae_recon

Config

#cfg = OmegaConf.load('../configs/config.yaml')
# Load the Swin experiment config; switch to the commented line above for the ViT config.
cfg = OmegaConf.load('../configs/config_swin.yaml')
#device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
device = 'cpu'  # leave GPU free for training while we do analysis here.
print(f'device = {device}')
device = cpu

Load Dataset

# Validation split, with shift-augmentation bounds taken from the training config section.
val_ds = PRPairDataset(image_dataset_dir=cfg.data.path, split='val', max_shift_x=cfg.training.max_shift_x, max_shift_y=cfg.training.max_shift_y)
# drop_last=True keeps every batch the same size for the visualizations below.
val_dl = DataLoader(val_ds, batch_size=cfg.training.batch_size, num_workers=4, drop_last=True)
print(f'Loaded {len(val_ds)} validation samples, batch_size = {cfg.training.batch_size}')
Loading 91 val files from /home/shawley/datasets/POP909_images_basic...
Finished loading.
Loaded 9100 validation samples, batch_size = 380

Inspect Data

# Pull a single validation batch and move each tensor onto the analysis device.
batch = next(iter(val_dl))
img1 = batch['img1'].to(device)
img2 = batch['img2'].to(device)
deltas = batch['deltas'].to(device)
file_idx = batch['file_idx'].to(device)
print("img1.shape, deltas.shape, file_idx.shape =", tuple(img1.shape), tuple(deltas.shape), tuple(file_idx.shape))
img1.shape, deltas.shape, file_idx.shape = (380, 1, 128, 128) (380, 2) (380,)
# Show a sample image pair
idx = 0  # which batch element to display
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
# channel 0 of each (1, H, W) image, displayed as grayscale
axes[0].imshow(img1[idx, 0].cpu(), cmap='gray')
axes[0].set_title(f'Image 1 (file_idx={file_idx[idx].item()})')
axes[1].imshow(img2[idx, 0].cpu(), cmap='gray')
axes[1].set_title(f'Image 2 (deltas = {deltas[idx].cpu().int().numpy()})')
plt.tight_layout()
plt.show()

Load Encoder from Checkpoint

# if cfg.model.get('encoder', 'vit') == 'swin':

# model = ViTEncoder(cfg.data.in_channels, (cfg.data.image_size, cfg.data.image_size), 
#                    cfg.model.patch_size, cfg.model.dim, cfg.model.depth, cfg.model.heads).to(device)

# ckpt_path = f'../checkpoints/{}__best.pt'  # <-- change as needed
# ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
# state_dict = {k.replace('_orig_mod.', ''): v for k, v in ckpt['model_state_dict'].items()}
# model.load_state_dict(state_dict, strict=False)

# Build the encoder family named in the config (default 'vit'), then load its
# best checkpoint; the path can be overridden via cfg.encoder_ckpt.
if cfg.model.get('encoder', 'vit') == 'swin':
    encoder = SwinEncoder(img_height=cfg.data.image_size, img_width=cfg.data.image_size,
                    patch_h=cfg.model.patch_h, patch_w=cfg.model.patch_w,
                    embed_dim=cfg.model.embed_dim, depths=cfg.model.depths,
                    num_heads=cfg.model.num_heads, window_size=cfg.model.window_size,
                    mlp_ratio=cfg.model.mlp_ratio, drop_path_rate=cfg.model.drop_path_rate).to(device)
else:
    # Pass the image size as an (H, W) tuple, matching how ViTDecoder is built
    # later and the earlier commented-out ViTEncoder call; the previous bare
    # int here looked like a bug -- TODO confirm against ViTEncoder's signature.
    encoder = ViTEncoder(cfg.data.in_channels, (cfg.data.image_size, cfg.data.image_size),
                         cfg.model.patch_size, cfg.model.dim, cfg.model.depth, cfg.model.heads).to(device)
encoder = load_checkpoint(encoder, cfg.get('encoder_ckpt', f'../checkpoints/{encoder.__class__.__name__}__best.pt'))
encoder.eval()  # inference mode: disable dropout/droppath, freeze batch-stat updates
print(f"Loaded {encoder.__class__.__name__}")
>>> Loaded model checkpoint from ../checkpoints/SwinEncoder__best.pt
Loaded SwinEncoder

Run Batch Through Encoder

# Encode both halves of each pair; gradients are not needed for analysis.
with torch.no_grad():
    enc_out1 = encoder(img1)
    enc_out2 = encoder(img2)

# (Stale scratch code kept for reference: flattened all patch embeddings into
# (B*num_tokens, dim) matrices for earlier analyses.)
#     z1 = enc_out1.patches.all_emb.reshape(-1, enc_out1.patches[1].dim)
#     z2 = enc_out2.patches.all_emb.reshape(-1, enc_out2.patches[1].dim)
#     num_tokens = enc_out1.patches.all_emb.shape[1]

# print(f'z1: {z1.shape}, z2: {z2.shape}, num_tokens: {num_tokens}')

Visualize Embeddings

NOTE: This will visualize all embeddings in the entire batch, not just the single pair of images shown above.

# Build the embedding-visualization figures for the whole batch (UMAP skipped for speed).
figs = make_emb_viz((enc_out1, enc_out2), encoder=encoder, batch=batch, do_umap=False)
figs.keys() # show what figures are available
dict_keys(['cls_pca_fig', 'cls_umap_fig', 'patch_pca_fig', 'patch_umap_fig', 'empty_pca_fig'])

Next code cell reads:

figs['cls_pca_fig'].show()

Make sure the next code cell is hidden or else the plotly.js will swamp the LLM context.

figs['cls_pca_fig'].show()  # interactive PCA plot of the CLS-token embeddings

Note how the CLS tokens are nicely grouped in pairs. Let’s see if the same is true for the randomly-sampled pairs of non-empty patch embeddings 🤞:

Next code cell reads:

figs['patch_pca_fig'].show()

Make sure the next code cell is hidden or else the plotly.js will swamp the LLM context.

figs['patch_pca_fig'].show()  # interactive PCA plot of sampled non-empty patch embeddings

SVD Analysis

def svd_analysis(enc_out, level=1,  title='', top_k=20):
    """Run SVD on the encoder embeddings at `level`, plot the singular-value
    spectrum, top-k variance shares, and cumulative variance, and report how
    many components reach 90% variance. Returns (S, U, Vt, var_exp)."""
    feats = enc_out.patches[level].emb.detach().cpu().float()
    z = feats.reshape(-1, enc_out.patches[level].dim)  # flatten batch into (N, dim)
    z = z - z.mean(dim=0)  # center columns so the SVD reflects variance
    # full_matrices=False -> thin SVD; Vt is V transposed (real-valued data,
    # so the Hermitian transpose is just the transpose)
    U, S, Vt = torch.linalg.svd(z, full_matrices=False)
    energy = S ** 2
    var_exp = energy / energy.sum()
    cum_var = var_exp.cumsum(0)

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    spec_ax, bar_ax, cum_ax = axes
    spec_ax.semilogy(S.numpy())
    spec_ax.axvline(x=top_k, color='r', ls='--', alpha=0.5)
    spec_ax.set(xlabel='Component', ylabel='Singular value', title=f'{title} Singular Values')
    bar_ax.bar(range(top_k), var_exp[:top_k].numpy())
    bar_ax.set(xlabel='Component', ylabel='Variance explained', title=f'{title} Top {top_k} Variance')
    cum_ax.plot(cum_var.numpy())
    cum_ax.axhline(y=0.9, color='r', ls='--', alpha=0.5, label='90%')
    cum_ax.set(xlabel='Component', ylabel='Cumulative variance', title=f'{title} Cumulative Variance')
    cum_ax.legend()
    plt.tight_layout()
    plt.show()

    # number of components strictly below 90%, plus the one that crosses it
    n90 = (cum_var < 0.9).sum().item() + 1
    print(f"{title}: {n90} components for 90% variance, top-1 explains {var_exp[0]:.1%}")
    return S, U, Vt, var_exp
S, U, Vt, var_exp = svd_analysis(enc_out2, title='Patches')  # default level=1: patch embeddings

Patches: 7 components for 90% variance, top-1 explains 59.3%
cls_S, cls_U, cls_Vt, cls_var_exp = svd_analysis(enc_out2, level=0, title='CLS')  # level=0: CLS tokens

CLS: 2 components for 90% variance, top-1 explains 85.9%

Two key takeaways (NOTE: the dimension counts quoted below appear to come from an earlier ViT run — the Swin run printed above needs only 7 and 2 components respectively, a much more concentrated spectrum):

  1. Patches need 178/256 dims for 90%. The representation is highly distributed with no dominant direction. This means the encoder is using nearly all its capacity, which is healthy (no dimensional collapse). But it also suggests rhythm and pitch aren’t cleanly factored — if they were, you’d expect a sharper elbow in the spectrum (the first 1 or 2 components notwithstanding).
  2. CLS only needs 23/256 dims. The global summary is much more compressed. That’s interesting for generation: it suggests the “gist” of a musical passage lives in a ~23-dimensional subspace. The gradual slope in the top-20 bars (no single dominant component) means it’s not collapsing to a trivial representation either.

Decoder Performance

# Build the decoder that matches the configured encoder family, then load its
# best checkpoint. dec_* config keys fall back to the encoder's values.
if cfg.model.get('encoder', 'vit') == 'swin': # decoder should match encoder
    decoder = SwinDecoder(img_height=cfg.data.image_size, img_width=cfg.data.image_size,
                        patch_h=cfg.model.patch_h, patch_w=cfg.model.patch_w,
                        embed_dim=cfg.model.embed_dim,
                        depths=cfg.model.get('dec_depths', cfg.model.depths), 
                        num_heads=cfg.model.get('dec_num_heads', cfg.model.num_heads)).to(device)
else: 
    decoder = ViTDecoder(cfg.data.in_channels, (cfg.data.image_size, cfg.data.image_size),
                     cfg.model.patch_size, cfg.model.dim, 
                     cfg.model.get('dec_depth', 4), cfg.model.get('dec_heads', 8)).to(device)

name = decoder.__class__.__name__
print("Name = ",name)
# BUG FIX: this previously read cfg.get('encoder_ckpt', ...), which would have
# loaded the *encoder's* checkpoint into the decoder whenever encoder_ckpt was
# set in the config; the default-path fallback masked the bug in this run.
decoder = load_checkpoint(decoder, cfg.get('decoder_ckpt', f'../checkpoints/{decoder.__class__.__name__}__best.pt'))
Name =  SwinDecoder
>>> Loaded model checkpoint from ../checkpoints/SwinDecoder__best.pt
# Decode the second image's embeddings; the decoder emits logits, so apply
# sigmoid to get per-pixel probabilities in [0, 1].
recon_logits = decoder(enc_out2)
img_recon = torch.sigmoid(recon_logits) 
img_real = img2
print("img_recon.shape, img_real.shape =",img_recon.shape, img_real.shape)
img_recon.shape, img_real.shape = torch.Size([380, 1, 128, 128]) torch.Size([380, 1, 128, 128])
# Build image grids (real / reconstruction / error map) plus evaluation metrics.
grid_recon, grid_real, grid_map, evals = viz_mae_recon(img_recon, img_real, enc_out=None, epoch=0, debug=False, return_maps=True)

# permute(1,2,0) moves channels last for matplotlib's imshow
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 5))
ax1.imshow(grid_real.permute(1,2,0), cmap='gray'); ax1.set_title('Real')
ax2.imshow(grid_recon.permute(1,2,0), cmap='gray'); ax2.set_title('Recon')
ax3.imshow(grid_map.permute(1,2,0)); ax3.set_title('Map')
plt.show()
# Print only the scalar metrics, skipping the *_map image entries.
print(', '.join(f"{k}: {v.item():.4f}" for k, v in evals.items() if not k.endswith('map')))

precision: 0.9990, recall: 0.9998, specificity: 1.0000, f1: 0.9994

Wow! F1 = 0.9994!

That’s nearly perfect reconstruction: F1 accuracy of 99.94%. Seems we have our representation autoencoder!

Let’s Show the map image really big. It’s designed to show red pixels wherever there are False Positives and yellow pixels wherever there are False Negatives (and white = True Pos, black = True Neg)…
I don’t see any red or yellow, do you?

In the next cell we’re gonna plot an image showing the maps as a very large image, we’re gonna hide it from the LLM because it doesn’t need to see it and we wanna spare the context.

from PIL import Image
from IPython.display import display

# Show the error map at full resolution: scale [0,1] -> [0,255], move channels
# last, and cast to uint8 before handing it to PIL.
map_hwc = (grid_map * 255).permute(1, 2, 0).byte().numpy()
img = Image.fromarray(map_hwc)
display(img)